import copy
from typing import Optional

from moto.stepfunctions.parser.api import HistoryEventType
from moto.stepfunctions.parser.asl.component.common.catch.catch_outcome import (
    CatchOutcome,
)
from moto.stepfunctions.parser.asl.component.common.error_name.failure_event import (
    FailureEvent,
    FailureEventException,
)
from moto.stepfunctions.parser.asl.component.common.parargs import Parargs
from moto.stepfunctions.parser.asl.component.common.retry.retry_outcome import (
    RetryOutcome,
)
from moto.stepfunctions.parser.asl.component.state.exec.execute_state import (
    ExecutionState,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_parallel.branches_decl import (
    BranchesDecl,
)
from moto.stepfunctions.parser.asl.component.state.state_props import StateProps
from moto.stepfunctions.parser.asl.eval.environment import Environment


class StateParallel(ExecutionState):
    """Execution state that evaluates a set of branch state machines on one input.

    Emits the ParallelState* history events around branch evaluation and applies
    the state's Retry / Catch configuration when evaluation raises a
    ``FailureEventException``.
    """

    # Branches (Required)
    # An array of objects that specify state machines to execute in parallel. Each such state
    # machine object must have fields named States and StartAt, whose meanings are exactly
    # like those in the top level of a state machine.
    branches: BranchesDecl
    # Optional Parameters/Arguments declaration; when set, it computes the branches' input
    # instead of using the current state input.
    parargs: Optional[Parargs]

    def __init__(self):
        super().__init__(
            state_entered_event_type=HistoryEventType.ParallelStateEntered,
            state_exited_event_type=HistoryEventType.ParallelStateExited,
        )

    def from_state_props(self, state_props: StateProps) -> None:
        """Populate this state from its parsed properties.

        ``Branches`` is mandatory: a ``ValueError`` is handed to ``StateProps.get``
        as the error to raise when it is missing. ``Parargs`` is optional and
        ``self.parargs`` is expected to be ``None`` when it is not declared
        (assumes ``StateProps.get`` returns ``None`` for absent props — confirm).
        """
        super().from_state_props(state_props)
        self.branches = state_props.get(
            typ=BranchesDecl,
            raise_on_missing=ValueError(
                f"Missing Branches definition in props '{state_props}'."
            ),
        )
        self.parargs = state_props.get(Parargs)

    def _eval_execution(self, env: Environment) -> None:
        """Evaluate all branches, bracketed by Started/Succeeded history events.

        ``ParallelStateSucceeded`` is only recorded if ``self.branches.eval``
        returns without raising; failures propagate to ``_eval_state``.
        """
        env.event_manager.add_event(
            context=env.event_history_context,
            event_type=HistoryEventType.ParallelStateStarted,
        )
        self.branches.eval(env)
        env.event_manager.add_event(
            context=env.event_history_context,
            event_type=HistoryEventType.ParallelStateSucceeded,
            update_source_event_id=False,
        )

    def _eval_state(self, env: Environment) -> None:
        """Run the branches, applying Retry and Catch handling on failure.

        The branches' input is computed once, deep-copied and cached. Every
        attempt (the first one and each retry) receives its own deep copy of that
        cached value, so changes a failed attempt made to its input cannot leak
        into a retry.
        """
        # Initialise the retry counter for execution states.
        env.states.context_object.context_object_data["State"]["RetryCount"] = 0

        # Compute the branches' input: if declared this is the parameters, else the current memory state.
        # (parargs.eval presumably leaves its result on top of env.stack — confirm.)
        if self.parargs is not None:
            self.parargs.eval(env=env)
        # In both cases, the input is copied by value to avoid cross branch state manipulation, and
        # cached to allow it to be resubmitted in case of failure.
        input_value = copy.deepcopy(env.stack.pop())

        # Attempt to evaluate the state's logic through until it's successful, caught, or retries have run out.
        while env.is_running():
            try:
                # Hand each attempt a fresh copy so the cached input stays pristine for retries.
                env.stack.append(copy.deepcopy(input_value))
                self._evaluate_with_timeout(env)
                break
            except FailureEventException as failure_event_ex:
                failure_event: FailureEvent = self._from_error(
                    env=env, ex=failure_event_ex
                )
                error_output = self._construct_error_output_value(
                    failure_event=failure_event
                )
                env.states.set_error_output(error_output)
                env.states.set_result(error_output)

                if self.retry is not None:
                    retry_outcome: RetryOutcome = self._handle_retry(
                        env=env, failure_event=failure_event
                    )
                    if retry_outcome == RetryOutcome.CanRetry:
                        continue

                # Retries are absent or exhausted: the state has failed.
                env.event_manager.add_event(
                    context=env.event_history_context,
                    event_type=HistoryEventType.ParallelStateFailed,
                )

                if self.catch is not None:
                    self._handle_catch(env=env, failure_event=failure_event)
                    # The catch outcome is read from the top of the stack after _handle_catch.
                    catch_outcome: CatchOutcome = env.stack[-1]
                    if catch_outcome == CatchOutcome.Caught:
                        break

                # No break here: the loop re-checks env.is_running(); presumably
                # _handle_uncaught stops the execution (or raises) — confirm.
                self._handle_uncaught(env=env, failure_event=failure_event)
